{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating using ipynb\n",
    "\n",
    "While we provide evaluation scripts, if users are interested in using ipython notebooks to evaluate instead, this notebook should be useful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iterator import SmartIterator\n",
    "from utils.visualization_utils import get_att_map, objdict, get_dict\n",
    "from models import ReferringRelationshipsModel\n",
    "from utils.eval_utils import iou_bbox\n",
    "\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from keras import backend as K\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "import json\n",
    "import h5py\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_auc_score\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics\n",
    "Here are implementations for a bunch of different metrics. We don't use all of them in the paper but they are available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "########### METRICS #########################################\n",
    "\n",
    "def sim_metric_np(y_true, y_pred, eps=10e-8):\n",
    "    y_true = (y_true.T/(eps + y_true.sum(axis=1).T)).T\n",
    "    y_pred = (y_pred.T/(eps + y_pred.sum(axis=1).T)).T\n",
    "    mini = ((y_true*(y_true<y_pred)) + (y_pred*(y_pred<y_true))).sum(axis=1)\n",
    "    return list(mini)\n",
    "\n",
    "def iou_np(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    intersection = (y_pred * y_true).sum(axis=1)\n",
    "    union = eps + ((y_pred + y_true)>0).sum(axis=1)\n",
    "    return list(intersection/union)\n",
    "\n",
    "def recall_np(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    tp = (y_pred * y_true).sum(axis=1)\n",
    "    fn = (1*((y_true - y_pred)>0)).sum(axis=1)\n",
    "    recall = tp/(tp+fn+eps)\n",
    "    return list(recall)\n",
    "\n",
    "def precision_np(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    y_pred = y_pred > thresh\n",
    "    tp = (y_pred * y_true).sum(axis=1)\n",
    "    p = y_pred.sum(axis=1)\n",
    "    prec = tp/(p+eps)\n",
    "    return list(prec)\n",
    "\n",
    "def kl_metric_np(y_true, y_pred, eps=10e-8):\n",
    "    y_true = (y_true.T/(eps + y_true.sum(axis=1).T)).T\n",
    "    y_pred = (y_pred.T/(eps + y_pred.sum(axis=1).T)).T\n",
    "    x = np.log(eps+(y_true/(eps+y_pred)))\n",
    "    return list((x*y_true).sum(axis=1))\n",
    "\n",
    "def cc_metric_np(y_true, y_pred, eps=10e-8):\n",
    "    sigma_true = y_true.var(axis=1)\n",
    "    sigma_pred = y_pred.var(axis=1)\n",
    "    cov = ((y_true-y_true.mean(axis=1, keepdims=True)) * (y_pred - y_pred.mean(axis=1, keepdims=True))).mean(axis=1)\n",
    "    return list(cov/np.sqrt((sigma_true*sigma_true)+eps))\n",
    "\n",
    "def iou_bbox_np(y_true, y_pred, thresh=0.5, eps=10e-8):\n",
    "    gt_bbox = get_bbox_from_heatmap(y_true, thresh)\n",
    "    pred_bbox = get_bbox_from_heatmap(y_pred, thresh)\n",
    "    return list(iou(gt_bbox, pred_bbox))\n",
    "\n",
    "def roc_auc(y_true, y_pred):\n",
    "    scores = []\n",
    "    for i in range(y_true.shape[0]):\n",
    "        if y_true[i].sum()>0:\n",
    "            scores += [roc_auc_score(y_true[i], y_pred[i], average='micro')]\n",
    "    return scores\n",
    "\n",
    "def pixel_acc(y_true, y_pred, thresh=0.5):\n",
    "    y_pred = y_pred > thresh\n",
    "    acc = (y_pred == y_true).mean(axis=1)\n",
    "    return list(acc)\n",
    "\n",
    "########### HELPERS #########################################\n",
    "\n",
    "def get_bbox_from_heatmap(heatmap, threshold, input_dim=224):\n",
    "    heatmap = heatmap.reshape((-1, input_dim, input_dim)) \n",
    "    heatmap[heatmap < threshold] = 0\n",
    "    horiz = 1. * (heatmap.sum(axis=2, keepdims=True)>0)\n",
    "    horiz = horiz.repeat(input_dim, axis=2)\n",
    "    vert = 1. * (heatmap.sum(axis=1, keepdims=True)>0)\n",
    "    vert = vert.repeat(input_dim, axis=1)\n",
    "    mask = horiz * vert\n",
    "    return mask\n",
    "\n",
    "def iou(y_true, y_pred, eps=10e-8):\n",
    "    intersection = (y_pred * y_true).sum(axis=1)\n",
    "    union = eps + ((y_pred + y_true)>0).sum(axis=1)\n",
    "    return intersection/union"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(model_checkpoint):\n",
    "    params = objdict(json.load(open(os.path.join(os.path.dirname(model_checkpoint), \"args.json\"), \"r\")))\n",
    "    params.finetune_cnn = True\n",
    "    model_weights = h5py.File(model_checkpoint)\n",
    "    relationships_model = ReferringRelationshipsModel(params)\n",
    "    model = relationships_model.build_model()\n",
    "    model.load_weights(model_checkpoint)\n",
    "    return model, params\n",
    "    \n",
    "\n",
    "def evaluate_model(model, params, test_data_dir, batch_size, metrics=[iou_np, recall_np, precision_np, kl_metric_np, cc_metric_np, sim_metric_np, pixel_acc]):\n",
    "    params.batch_size = batch_size\n",
    "    params.shuffle = False\n",
    "    test_generator = SmartIterator(test_data_dir, params)\n",
    "    results = {}\n",
    "    for metric in metrics:\n",
    "        results[metric.__name__+'_s'] = []\n",
    "        results[metric.__name__+'_o'] = []\n",
    "    for i in range(len(test_generator)):\n",
    "        if i%10 == 0:\n",
    "            print(\"{}/{}\".format(i, len(test_generator)))\n",
    "        batch_in, batch_out = test_generator[i]\n",
    "        preds = model.predict(batch_in)\n",
    "        for metric in metrics:\n",
    "            results[metric.__name__+'_s'] += metric(batch_out[0], preds[0])\n",
    "            results[metric.__name__+'_o'] += metric(batch_out[1], preds[1])\n",
    "    final= \"\"\n",
    "    for metric in metrics:\n",
    "        mean_s = np.mean(results[metric.__name__+'_s'])\n",
    "        mean_o = np.mean(results[metric.__name__+'_o'])\n",
    "        print(\"{} : {:.4f} & {:.4f} \".format(metric.__name__, mean_s, mean_o))\n",
    "        final += \" {:.4f} & {:.4f} \".format(mean_s, mean_o)\n",
    "    print(final)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's load the pretrained model for VRD and evaluate it on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = \"pretrained/vrd.h5\"\n",
    "model, params = load_model(model_checkpoint)\n",
    "results = evaluate_model(model, params, test_data_dir, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating per category\n",
    "In case you are interested in evaluating per category, the code below should help."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model_per_cat(model, params, metrics=[iou_np, recall_np, precision_np, roc_auc]):\n",
    "    test_data_dir = \"/data/ranjaykrishna/ReferringRelationships/data/dataset-vrd-14/test\"\n",
    "    params.batch_size = 140\n",
    "    categories = json.load(open('data/VRD/objects.json', 'r'))\n",
    "    params.baseline_weights = None\n",
    "    params.shuffle = False\n",
    "    test_generator = SmartIterator(test_data_dir, params)\n",
    "    results = {}\n",
    "    for i in range(len(categories)):\n",
    "        results[i] = {}\n",
    "        for metric in metrics:\n",
    "            results[i][metric.__name__+'_s'] = []\n",
    "            results[i][metric.__name__+'_o'] = []\n",
    "    for i in range(len(test_generator)):\n",
    "        if i%10 == 0:\n",
    "            print(\"{}/{}\".format(i, len(test_generator)))\n",
    "        batch_in, batch_out = test_generator[i]\n",
    "        preds = model.predict(batch_in)\n",
    "        for j in range(len(categories)):\n",
    "            indices = batch_in[1] == j\n",
    "            if indices.sum() > 0:\n",
    "                sub_pred = [preds[0][indices.flatten()], preds[1][indices.flatten()]]\n",
    "                sub_batch_out = [batch_out[0][indices.flatten()], batch_out[1][indices.flatten()]]\n",
    "                for metric in metrics:\n",
    "                    results[j][metric.__name__+'_s'] += metric(sub_batch_out[0], sub_pred[0])\n",
    "            indices = batch_in[2] == j\n",
    "            if indices.sum() > 0:\n",
    "                sub_pred = [preds[0][indices.flatten()], preds[1][indices.flatten()]]\n",
    "                sub_batch_out = [batch_out[0][indices.flatten()], batch_out[1][indices.flatten()]]\n",
    "                for metric in metrics:\n",
    "                    results[j][metric.__name__+'_o'] += metric(sub_batch_out[1], sub_pred[1])\n",
    "    for i in results.keys():\n",
    "        for j in results[i].keys():\n",
    "            results[i][j] = np.mean(results[i][j])\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for cat in results_cat:\n",
    "    \"Subject IOU per cat\"\n",
    "    print(\"{} : {}\".format(categories[cat], results_cat[cat][\"precision_np_s\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for cat in results_cat:\n",
    "    \"Object IOU per cat\"\n",
    "    print(\"{} : {}\".format(categories[cat], results_cat[cat][\"precision_np_o\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
